#!/bin/bash
#SBATCH --job-name={{ job_name }}
#SBATCH --nodes={{ total_nodes }}
#SBATCH --ntasks={{ total_nodes }}
#SBATCH --ntasks-per-node=1
#SBATCH --account={{ account }}
#SBATCH --time={{ time_limit }}
#SBATCH --output=logs/%j_{{ agg_workers }}A_{{ timestamp }}/log.out
#SBATCH --error=logs/%j_{{ agg_workers }}A_{{ timestamp }}/log.err
#SBATCH --partition={{ partition }}

# Constants
# The placeholders below are substituted when this template is rendered
# (presumably by Jinja2 - confirm against the submit tooling).
# Tracing is enabled only around the assignments so the rendered values are
# echoed to the job's stderr log.
set -x
AGG_NODES={{ agg_nodes }}
AGG_WORKERS={{ agg_workers }}
TOTAL_NODES={{ total_nodes }}
GPUS_PER_NODE={{ gpus_per_node }}
# GPU count across all aggregated-worker nodes.
TOTAL_GPUS=$((AGG_NODES * GPUS_PER_NODE))
# PREFILL_GPUS and DECODE_GPUS are only forwarded to the benchmark script at the
# end of this file; every GPU is reported as a decode GPU.
# NOTE(review): presumably because aggregated mode has no separate prefill pool - confirm.
PREFILL_GPUS=0
DECODE_GPUS=$TOTAL_GPUS
# Integer division: any remainder of AGG_NODES / AGG_WORKERS is dropped, so
# leftover nodes are not assigned to a worker.
AGG_NODES_PER_WORKER=$((AGG_NODES / AGG_WORKERS))
# Per-job log directory under the submit directory; it uses the same
# logs/<jobid>_<workers>A_<timestamp> naming as the SBATCH output/error paths.
LOG_DIR="${SLURM_SUBMIT_DIR}/logs/${SLURM_JOB_ID}_{{ agg_workers }}A_{{ timestamp }}"
SCRIPT_DIR="${SLURM_SUBMIT_DIR}/scripts"
OUTPUT_DIR="${SLURM_SUBMIT_DIR}/outputs"
# The directories below are bind-mounted into the container later in this file
# (/model/, /configs/).
MODEL_DIR="{{ model_dir }}"
CONFIG_DIR="{{ config_dir }}"
CONTAINER_IMAGE="{{ container_image }}"
# Interface queried with "ip addr show" to discover node IPv4 addresses.
NETWORK_INTERFACE="{{ network_interface }}"
# Falls back to h100 when the template is rendered without a gpu_type.
GPU_TYPE="{{ gpu_type | default('h100') }}"
set +x

{% raw %}

# Make sure the output and log directories exist before any step writes to them.
mkdir -p "${OUTPUT_DIR}" "${LOG_DIR}"

# Expand the allocation (possibly in compressed form such as "node[01-04]") into
# one hostname per array element.
# SLURM_NODELIST is quoted so its bracket ranges can never be treated as a shell
# glob, and mapfile reads one hostname per line without word-splitting or
# globbing the command output.
mapfile -t nodes < <(scontrol show hostnames "$SLURM_NODELIST")
if [ "${#nodes[@]}" -ne "$TOTAL_NODES" ]; then
    echo "Error: Expected $TOTAL_NODES nodes but got ${#nodes[@]} nodes"
    exit 1
fi

# Print node information
for i in "${!nodes[@]}"; do
    echo "Node $i: ${nodes[$i]}"
done

{% endraw %}
{% if enable_multiple_frontends %}
{% raw %}
# Multi-frontend layout:
#   nodes[0]  -> nginx, sharing the node with an aggregated worker shard
#   nodes[1]  -> NATS/ETCD plus the first frontend
#   nodes[2+] -> aggregated workers, optionally hosting additional frontends

NGINX_NODE="${nodes[0]}"
MASTER_NODE="${nodes[1]}"

# Ask the master node for its IPv4 address on the configured interface.
MASTER_IP=$(
    srun --nodes=1 --ntasks=1 --nodelist=$MASTER_NODE ip addr show $NETWORK_INTERFACE \
        | grep 'inet ' \
        | awk '{print $2}' \
        | cut -d'/' -f1
)

# Without the master address no worker can be pointed at it, so stop here.
[[ -n "${MASTER_IP}" ]] || {
    echo "Error: Could not retrieve IP address for master host $MASTER_NODE on interface $NETWORK_INTERFACE"
    exit 1
}

echo "Master IP address (node 1): $MASTER_IP"
echo "Nginx node (node 0): $NGINX_NODE"

# Generate frontend IP list for nginx config.
# frontend_hosts / frontend_ips are parallel arrays: element N of each describes
# the same frontend. Element 0 is always the master node.
frontend_hosts=()
frontend_ips=()
# Node 1 always has a frontend (with NATS/ETCD)
frontend_hosts+=("$MASTER_NODE")
frontend_ips+=("$MASTER_IP")

# Add additional frontends based on num_additional_frontends
{% endraw %}ADDITIONAL_FRONTENDS={{ num_additional_frontends }}{% raw %}
if [ "$ADDITIONAL_FRONTENDS" -gt 0 ]; then
    # Calculate which nodes get additional frontends
    # We have AGG_NODES aggregated worker nodes, distribute additional frontends across them
    nodes_per_frontend=$(( (AGG_NODES - 1 + ADDITIONAL_FRONTENDS - 1) / ADDITIONAL_FRONTENDS ))  # ceil division
    frontend_node_idx=2  # Start from node 2 (node 1 already has frontend)

    for i in $(seq 1 "$ADDITIONAL_FRONTENDS"); do
        # Requested frontends that would fall past the last allocated node are skipped.
        if [ "$frontend_node_idx" -lt "$TOTAL_NODES" ]; then
            node_name=${nodes[$frontend_node_idx]}
            node_ip=$(srun --nodes=1 --ntasks=1 --nodelist=$node_name ip addr show $NETWORK_INTERFACE | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1)
            # Same guard as for the master IP: an empty address must never reach
            # the list that is rendered into the nginx configuration.
            if [ -z "$node_ip" ]; then
                echo "Error: Could not retrieve IP address for frontend host $node_name on interface $NETWORK_INTERFACE"
                exit 1
            fi
            frontend_hosts+=("$node_name")
            frontend_ips+=("$node_ip")
            echo "Additional frontend $i on node $frontend_node_idx: $node_name ($node_ip)"
            frontend_node_idx=$((frontend_node_idx + nodes_per_frontend))
        fi
    done
fi

echo "Frontend hosts: ${frontend_hosts[*]}"
echo "Frontend IPs: ${frontend_ips[*]}"

# Generate nginx configuration
# Build a Python list literal of frontend IPs from the bash array, e.g. ['10.0.0.1','10.0.0.2'].
FRONTEND_LIST=$(printf "'%s'," "${frontend_ips[@]}")
FRONTEND_LIST="[${FRONTEND_LIST%,}]"
# The embedded Python reads its inputs from the environment.
export FRONTEND_LIST SCRIPT_DIR LOG_DIR
# Render ${SCRIPT_DIR}/nginx.conf.j2 into ${LOG_DIR}/nginx.conf. The heredoc
# delimiter is quoted, so the shell expands nothing inside the Python source.
# A failed render (missing jinja2, missing template, unwritable log dir) aborts
# the job here instead of surfacing later as an nginx start-up failure.
if ! python3 - <<'PY'
import ast
import os
from jinja2 import Template

template_path = os.path.join(os.environ['SCRIPT_DIR'], 'nginx.conf.j2')
output_path = os.path.join(os.environ['LOG_DIR'], 'nginx.conf')

with open(template_path, 'r') as f:
    tmpl = Template(f.read())

# literal_eval only accepts Python literals, so the environment string can
# never execute code the way eval() could.
frontend_hosts = ast.literal_eval(os.environ['FRONTEND_LIST'])
config = tmpl.render(frontend_hosts=frontend_hosts)

with open(output_path, 'w') as f:
    f.write(config)
PY
then
    echo "Error: Failed to generate nginx configuration at ${LOG_DIR}/nginx.conf"
    exit 1
fi

{% endraw %}
{% else %}
{% raw %}
# Traditional layout: nodes[0] hosts the first aggregated worker and acts as master.
# Ask it for its IPv4 address on the configured interface.
MASTER_IP=$(
    srun --nodes=1 --ntasks=1 --nodelist=${nodes[0]} ip addr show $NETWORK_INTERFACE \
        | grep 'inet ' \
        | awk '{print $2}' \
        | cut -d'/' -f1
)

# Nothing can be wired up without the master address, so stop here.
[[ -n "${MASTER_IP}" ]] || {
    echo "Error: Could not retrieve IP address for master host ${nodes[0]} on interface $NETWORK_INTERFACE"
    exit 1
}

echo "Master IP address: $MASTER_IP"
{% endraw %}
{% endif %}
{% raw %}

# Compute leader nodes for each aggregated worker
{% endraw %}
{% if enable_multiple_frontends %}
{% raw %}
# With multiple frontends: keep offset 0; nginx coexists on node 0
# WORKER_NODE_OFFSET is added to every leader index computed below, so with 0
# the first aggregated worker still starts on nodes[0].
WORKER_NODE_OFFSET=0
{% endraw %}
{% else %}
{% raw %}
# Traditional: workers start from node 0
# WORKER_NODE_OFFSET is added to every leader index computed below.
WORKER_NODE_OFFSET=0
{% endraw %}
{% endif %}
{% raw %}

# For every aggregated worker, record the index (into the nodes array) of its
# leader node: workers occupy consecutive runs of AGG_NODES_PER_WORKER nodes,
# starting at WORKER_NODE_OFFSET.
agg_leaders=()
for (( w = 0; w < AGG_WORKERS; w++ )); do
    agg_leaders+=( $(( WORKER_NODE_OFFSET + w * AGG_NODES_PER_WORKER )) )
done

echo "Aggregated worker leaders: ${agg_leaders[*]}"

# Container options shared by every srun step. The string is later expanded
# unquoted so that it splits into separate srun arguments; the spacing is kept
# exactly as it appears when the commands are echoed to the log.
enroot_mounts="${MODEL_DIR}:/model/,${CONFIG_DIR}:/configs/,${SCRIPT_DIR}:/scripts/,${OUTPUT_DIR}:/outputs/,${LOG_DIR}:/logs/"
ENROOT_ARGS="    --container-image=${CONTAINER_IMAGE} "
ENROOT_ARGS+="    --no-container-entrypoint "
ENROOT_ARGS+="    --no-container-mount-home "
ENROOT_ARGS+="    --container-mounts=${enroot_mounts} "

# Build common worker arguments
{% endraw %}
SCRIPT_VARIANT="{{ script_variant | default('default') }}"
{% raw %}
# Flags passed to every worker_setup.py invocation, assembled one option at a time.
WORKER_ARGS="--gpu_type ${GPU_TYPE}"
WORKER_ARGS+=" --script-variant ${SCRIPT_VARIANT}"
WORKER_ARGS+=" --gpus_per_node ${GPUS_PER_NODE}"
WORKER_ARGS+=" --master_ip ${MASTER_IP}"
{% endraw %}
{% if enable_multiple_frontends %}
{% raw %}
# Tell worker setup that the multiple-frontend layout is in use
WORKER_ARGS+=" --multiple-frontends-enabled"
{% endraw %}
{% endif %}
{% if run_in_ci %}
{% raw %}
# Tell worker setup that the job runs in CI mode
WORKER_ARGS+=" --run-in-ci"
{% endraw %}
{% endif %}
{% raw %}

{% endraw %}
{% if enable_multiple_frontends %}
{% raw %}
# Start nginx on node 0 as a background job step, using the config rendered into the log dir.
echo "Launching nginx on ${NGINX_NODE}"
nginx_log_prefix="${LOG_DIR}/${NGINX_NODE}_nginx"
cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$NGINX_NODE"
cmd+=" --output=${nginx_log_prefix}.out --error=${nginx_log_prefix}.err"
cmd+=" python /scripts/worker_setup.py --worker_type nginx --nginx_config /logs/nginx.conf ${WORKER_ARGS}"
echo "$cmd"
# Deliberately unquoted: the string must split into the srun argument list.
$cmd &

# Launch frontend on master node (node 1) - this will also start NATS/ETCD
echo "Launching frontend + NATS/ETCD on master node ${MASTER_NODE}"
# Both log files carry the frontend index (0), matching the
# <node>_frontend_<idx>.out/.err naming used for the additional frontends.
cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$MASTER_NODE --output=${LOG_DIR}/${MASTER_NODE}_frontend_0.out --error=${LOG_DIR}/${MASTER_NODE}_frontend_0.err python /scripts/worker_setup.py --worker_type frontend --worker_idx 0 ${WORKER_ARGS}"
echo "$cmd"
# Deliberately unquoted: the string must split into the srun argument list.
$cmd &

# Launch additional frontends on designated nodes.
# The hosts are taken from frontend_hosts, the list built alongside the IPs that
# were passed to the nginx config template, so every address nginx was given has
# a frontend launched on it (recomputing the node spacing here could pick
# different nodes). frontend_hosts[0] is the master node, whose frontend 0 is
# launched separately, so additional frontends start at index 1.
if [ "$ADDITIONAL_FRONTENDS" -gt 0 ]; then
    for (( frontend_idx = 1; frontend_idx < ${#frontend_hosts[@]}; frontend_idx++ )); do
        node=${frontend_hosts[$frontend_idx]}
        echo "Launching additional frontend $frontend_idx on node: $node"
        cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_frontend_${frontend_idx}.out --error=${LOG_DIR}/${node}_frontend_${frontend_idx}.err python /scripts/worker_setup.py --worker_type frontend --worker_idx ${frontend_idx} ${WORKER_ARGS}"
        echo "$cmd"
        # Deliberately unquoted: the string must split into the srun argument list.
        $cmd &
    done
fi
{% endraw %}
{% else %}
{% raw %}
# Traditional: first aggregated worker node also runs frontend + NATS/ETCD
# This is handled in setup_aggregated_worker when worker_idx=0 and local_rank=0
{% endraw %}
{% endif %}
{% raw %}

# Launch aggregated workers.
# Each worker spans AGG_NODES_PER_WORKER consecutive nodes starting at its
# leader index; one background srun step is started per node.
for worker_idx in $(seq 0 $((AGG_WORKERS - 1))); do
    leader_idx=${agg_leaders[$worker_idx]}
    leader_node=${nodes[$leader_idx]}

    # Get leader IP for this worker group
    LEADER_IP=$(srun --nodes=1 --ntasks=1 --nodelist=$leader_node ip addr show $NETWORK_INTERFACE | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1)
    # Same guard as for the master IP: an empty value would leave --leader_ip
    # without an argument in the command built below.
    if [ -z "$LEADER_IP" ]; then
        echo "Error: Could not retrieve IP address for leader host $leader_node on interface $NETWORK_INTERFACE"
        exit 1
    fi
    echo "Aggregated worker $worker_idx leader: $leader_node ($LEADER_IP)"

    # Launch all nodes for this worker
    for node_idx in $(seq 0 $((AGG_NODES_PER_WORKER - 1))); do
        global_node_idx=$((leader_idx + node_idx))
        node=${nodes[$global_node_idx]}
        # Rank within the worker; the leader node is local_rank 0.
        local_rank=$node_idx

        echo "Launching aggregated worker $worker_idx, node $global_node_idx (local_rank $local_rank): $node"
{% endraw %}
{% if enable_config_dump %}
{% raw %}
        CONFIG_DUMP_ARG="--dump-config-path /logs/${node}_config.json"
{% endraw %}
{% else %}
{% raw %}
        CONFIG_DUMP_ARG=""
{% endraw %}
{% endif %}
{% raw %}
        cmd="srun --overlap $ENROOT_ARGS --nodes=1 --ntasks=1 --nodelist=$node --output=${LOG_DIR}/${node}_agg_w${worker_idx}.out --error=${LOG_DIR}/${node}_agg_w${worker_idx}.err python /scripts/worker_setup.py --leader_ip ${LEADER_IP} --worker_idx ${worker_idx} --local_rank ${local_rank} --nodes_per_worker ${AGG_NODES_PER_WORKER} --worker_type aggregated --gpu_utilization_log /logs/${node}_agg_w${worker_idx}_gpu_utilization.log ${CONFIG_DUMP_ARG} ${WORKER_ARGS}"
        echo "$cmd"
        # Deliberately unquoted: the string must split into the srun argument list.
        $cmd &
    done
done

printf '\n'
{% endraw %}
{% if enable_multiple_frontends %}
{% raw %}
# Multi-frontend layout: show the nginx endpoint and how to attach to both service nodes.
printf '%s\n' "Frontend available at: http://${NGINX_NODE}:8000"
printf '%s\n' "To connect to the nginx node:"
printf '%s\n' "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${NGINX_NODE} --overlap --pty bash"
printf '%s\n' "To connect to the master node (NATS/ETCD):"
printf '%s\n' "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${MASTER_NODE} --overlap --pty bash"
{% endraw %}
{% else %}
{% raw %}
# Traditional layout: the master is the first node of the allocation.
printf '%s\n' "To connect to the master node:"
printf '%s\n' "srun $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --overlap --pty bash"
{% endraw %}
{% endif %}
{% raw %}

# Reminder printed for both layouts.
printf '\n'
printf '%s\n' "Make sure to cancel the job at the end:"
printf '%s\n' "scancel $SLURM_JOB_ID"

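# Instead of waiting for all tasks to complete, wait only for the first background task to finish (normally the bench.sh run launched below, when profiling is enabled) and then exit with its status.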

{% endraw %}

# Benchmark selection. PROFILER_TYPE names the directory under /scripts/ that
# holds bench.sh; PROFILER_ARGS is appended to its command line.
# Both values are quoted so a rendered value containing whitespace stays a
# single assignment.
PROFILER_TYPE="{{ profiler_type }}"
PROFILER_ARGS="{{ profiler_arg }}"

{% if do_profile %}
{% raw %}
# Run the benchmark on the first node as a background step.
# PROFILER_ARGS is deliberately unquoted so it splits into separate arguments.
srun --nodes=1 --ntasks=1 $ENROOT_ARGS --jobid $SLURM_JOB_ID -w ${nodes[0]} --output=${LOG_DIR}/profile.out --error=${LOG_DIR}/profile.err --overlap bash /scripts/${PROFILER_TYPE}/bench.sh 0 $AGG_WORKERS $PREFILL_GPUS $DECODE_GPUS $TOTAL_GPUS ${PROFILER_ARGS} &
{% endraw %}
{% endif %}

{% raw %}
# Block until any one background job step terminates, then propagate its exit status.
first_exit_code=0
wait -n || first_exit_code=$?
printf '%s\n' "Script finished at $(date) with exit code ${first_exit_code}"
exit "${first_exit_code}"
{% endraw %}

